{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ebUMqK9mGIDm"
   },
   "source": [
    "## The basics: interactive NumPy on GPU and TPU\n",
    "\n",
    "---\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "27TqNtiQF97X"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import random\n",
    "\n",
    "key = random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cRWoxSCNGU4o"
   },
   "outputs": [],
   "source": [
    "key, subkey = random.split(key)\n",
    "# Draw from the fresh subkey; `key` stays unconsumed so it can be split again later.\n",
    "x = random.normal(subkey, (5000, 5000))\n",
    "\n",
    "print(x.shape)\n",
    "print(x.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "diPllsvgGfSA"
   },
   "outputs": [],
   "source": [
    "y = jnp.dot(x, x)\n",
    "print(y[0, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8-psauxnGiRk"
   },
   "outputs": [],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-2FMQ8UeoTJ8"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.plot(x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DRnwCKFuGk8P"
   },
   "outputs": [],
   "source": [
    "jnp.dot(x, x.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "z4VX5PkMHJIu"
   },
   "outputs": [],
   "source": [
    "print(jnp.dot(x, 2 * x)[[0, 2, 1, 0], ..., None, ::-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ORZ9Odu85BCJ"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x_cpu = np.array(x)\n",
    "%timeit -n 1 -r 1 np.dot(x_cpu, x_cpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5BKh0eeAGvO5"
   },
   "outputs": [],
   "source": [
    "%timeit -n 5 -r 5 jnp.dot(x, x).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fm4Q2zpFHUAu"
   },
   "source": [
    "## Automatic differentiation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MCIQbyUYHWn1"
   },
   "outputs": [],
   "source": [
    "from jax import grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kfqZpKYsHo4j"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  \"\"\"Piecewise scalar function: 2*x**3 when x > 0, otherwise 3*x.\"\"\"\n",
    "  return 2 * x ** 3 if x > 0 else 3 * x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K_26_odPHqLJ"
   },
   "outputs": [],
   "source": [
    "key = random.PRNGKey(0)\n",
    "x = random.normal(key, ())\n",
    "\n",
    "print(grad(f)(x))\n",
    "print(grad(f)(-x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "q5V3A6loHrhS"
   },
   "outputs": [],
   "source": [
    "print(grad(grad(f))(-x))\n",
    "print(grad(grad(grad(f)))(-x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ba4WY4ArHv8I"
   },
   "outputs": [],
   "source": [
    "def predict(params, inputs):\n",
    "  \"\"\"Run `inputs` through the layers in `params`, a sequence of (W, b) pairs.\n",
    "\n",
    "  Each layer computes dot(activations, W) + b. tanh is applied between layers,\n",
    "  and the last layer's pre-activation output is what gets returned.\n",
    "  \"\"\"\n",
    "  activations = inputs\n",
    "  for W, b in params:\n",
    "    outputs = jnp.dot(activations, W) + b\n",
    "    activations = jnp.tanh(outputs)  # feeds the next layer\n",
    "  return outputs                   # no activation on last layer\n",
    "\n",
    "def loss(params, batch):\n",
    "  \"\"\"Sum of squared differences between predict(params, inputs) and targets.\n",
    "\n",
    "  `batch` is an (inputs, targets) pair.\n",
    "  \"\"\"\n",
    "  inputs, targets = batch\n",
    "  residuals = predict(params, inputs) - targets\n",
    "  return jnp.sum(residuals ** 2)\n",
    "\n",
    "\n",
    "\n",
    "def init_layer(key, n_in, n_out):\n",
    "  \"\"\"Sample a standard-normal (W, b) pair with shapes (n_in, n_out) and (n_out,).\n",
    "\n",
    "  `key` is split in two so the weights and the bias are drawn from different keys.\n",
    "  \"\"\"\n",
    "  w_key, b_key = random.split(key)\n",
    "  weights = random.normal(w_key, (n_in, n_out))\n",
    "  bias = random.normal(b_key, (n_out,))\n",
    "  return weights, bias\n",
    "\n",
    "layer_sizes = [5, 2, 3]\n",
    "\n",
    "# Split off one key per layer; the first piece is carried forward as `key`.\n",
    "key = random.PRNGKey(0)\n",
    "key, *keys = random.split(key, len(layer_sizes))\n",
    "params = [init_layer(layer_key, n_in, n_out)\n",
    "          for layer_key, n_in, n_out\n",
    "          in zip(keys, layer_sizes[:-1], layer_sizes[1:])]\n",
    "\n",
    "# Two more keys: one for the inputs, one for the targets.\n",
    "key, *keys = random.split(key, 3)\n",
    "input_key, target_key = keys\n",
    "inputs = random.normal(input_key, (8, 5))\n",
    "targets = random.normal(target_key, (8, 3))\n",
    "batch = (inputs, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LiTBibJdHz4K"
   },
   "outputs": [],
   "source": [
    "print(loss(params, batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "a3KFpwH3H4Cl"
   },
   "outputs": [],
   "source": [
    "step_size = 1e-2\n",
    "loss_grad = grad(loss)  # gradient with respect to the first argument, `params`\n",
    "\n",
    "# Twenty steps of plain gradient descent; `params` is rebound on every step.\n",
    "for _ in range(20):\n",
    "  grads = loss_grad(params, batch)\n",
    "  params = [(W - step_size * dW, b - step_size * db)\n",
    "            for (W, b), (dW, db) in zip(params, grads)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YLltDr0GH7LX"
   },
   "outputs": [],
   "source": [
    "print(loss(params, batch))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bmxAPFC0I8b0"
   },
   "source": [
    "Other JAX autodiff highlights:\n",
    "\n",
    "*   Forward- and reverse-mode, totally composable\n",
    "*   Fast Jacobians and Hessians\n",
    "*   Complex number support (holomorphic and non-holomorphic)\n",
    "*   Jacobian pre-accumulation for elementwise operations (like `gelu`)\n",
    "\n",
    "\n",
    "For much more, see the [JAX Autodiff Cookbook (Part 1)](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TRkxaVLJKNre"
   },
   "source": [
    "## End-to-end compilation with XLA using `jit`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "bKo4rX9-KSW7"
   },
   "outputs": [],
   "source": [
    "from jax import jit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "94iIgZSfKWh8"
   },
   "outputs": [],
   "source": [
    "key = random.PRNGKey(0)\n",
    "x = random.normal(key, (5000, 5000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Ybuz8Ag9KXMd"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  \"\"\"Apply y <- y - 0.1 * y + 3. ten times, then return the slice [:100, :100].\"\"\"\n",
    "  acc = x\n",
    "  for _ in range(10):\n",
    "    damped = acc - 0.1 * acc\n",
    "    acc = damped + 3.\n",
    "  return acc[:100, :100]\n",
    "\n",
    "f(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Y9dx5ifSKaGJ"
   },
   "outputs": [],
   "source": [
    "g = jit(f)\n",
    "g(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "UtsS67BvKYkC"
   },
   "outputs": [],
   "source": [
    "%timeit f(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-vfcaSo9KbvR"
   },
   "outputs": [],
   "source": [
    "%timeit g(x).block_until_ready()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "E3BQF1_AKeLn"
   },
   "outputs": [],
   "source": [
    "grad(jit(grad(jit(grad(jnp.tanh)))))(1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AvXl1WDPKjmV"
   },
   "source": [
    "### Constraints that come with using `jit`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mCtwRF18KnsE"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  \"\"\"Return 2*x**2 for x > 0 and 3*x otherwise; the branch depends on the value of x.\"\"\"\n",
    "  if x > 0:\n",
    "    return 2 * x ** 2\n",
    "  return 3 * x\n",
    "\n",
    "g = jit(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_82tY-ZSKqv4"
   },
   "outputs": [],
   "source": [
    "f(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TjSAFc-iKrcB"
   },
   "outputs": [],
   "source": [
    "# Calling the jitted function is expected to fail here; show the message instead of a traceback.\n",
    "try:\n",
    "    g(2)\n",
    "except Exception as err:\n",
    "    print(err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RhizP9pjKsug"
   },
   "outputs": [],
   "source": [
    "def f(x, n):\n",
    "  \"\"\"Square `x` repeatedly: one squaring per pass of a Python while loop bounded by `n`.\"\"\"\n",
    "  count = 0\n",
    "  while count < n:\n",
    "    x = x * x\n",
    "    count += 1\n",
    "  return x\n",
    "\n",
    "g = jit(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Wn6haTmUK-Q8"
   },
   "outputs": [],
   "source": [
    "f(jnp.array([1., 2., 3.]), 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HwBy1I04K-81"
   },
   "outputs": [],
   "source": [
    "# Calling the jitted function is expected to fail here; show the message instead of a traceback.\n",
    "try:\n",
    "    g(jnp.array([1., 2., 3.]), 5)\n",
    "except Exception as err:\n",
    "    print(err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XmaTryZaK_3M"
   },
   "outputs": [],
   "source": [
    "g = jit(f, static_argnums=(1,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HcWjxVktV4fa"
   },
   "outputs": [],
   "source": [
    "g(jnp.array([1., 2., 3.]), 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0M_-pJe7LOcO"
   },
   "source": [
    "## Vectorization with `vmap`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8XIot_ndLRH1"
   },
   "outputs": [],
   "source": [
    "from jax import vmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tRvCZn2wBkXP"
   },
   "outputs": [],
   "source": [
    "print(vmap(lambda x: x**2)(jnp.arange(8)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "icfsXizI_rkD"
   },
   "outputs": [],
   "source": [
    "from jax import make_jaxpr\n",
    "\n",
    "make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uQm4cvAbA6M3"
   },
   "outputs": [],
   "source": [
    "make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "NeiFfCHEBLsU"
   },
   "outputs": [],
   "source": [
    "make_jaxpr(vmap(vmap(jnp.dot)))(jnp.ones((10, 10, 8)), jnp.ones((10, 10, 8)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "csX71fkSCZrp"
   },
   "outputs": [],
   "source": [
    "perex_grads = vmap(grad(loss), in_axes=(None, 0))\n",
    "make_jaxpr(perex_grads)(params, batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Tmf1NT2Wqv5p"
   },
   "source": [
    "## Parallel accelerators with `pmap`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "t6RRAFn1CEln"
   },
   "outputs": [],
   "source": [
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tEK1I6Duqunw"
   },
   "outputs": [],
   "source": [
    "from jax import pmap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "S-iCNfeGqzkY"
   },
   "outputs": [],
   "source": [
    "y = pmap(lambda x: x ** 2)(jnp.arange(8))\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xgutf5JPP3wi"
   },
   "outputs": [],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xxShG3Tdq4Gj"
   },
   "outputs": [],
   "source": [
    "z = y / 2\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uvDL2_bCq7kq"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.plot(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Xg76CmLYq_Q6"
   },
   "outputs": [],
   "source": [
    "keys = random.split(random.PRNGKey(0), 8)\n",
    "mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n",
    "result = pmap(jnp.dot)(mats, mats)\n",
    "print(pmap(jnp.mean)(result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jbw_hRx7rDzX"
   },
   "outputs": [],
   "source": [
    "%timeit -n 5 -r 5 pmap(jnp.dot)(mats, mats).block_until_ready()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xf5N9ZRirJhL"
   },
   "source": [
    "### Collective communication operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9i1PfxUvrThh"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from jax.lax import psum\n",
    "\n",
    "@partial(pmap, axis_name='i')\n",
    "def normalize(x):\n",
    "  \"\"\"Divide x by its psum over the mapped axis named 'i'.\"\"\"\n",
    "  total = psum(x, 'i')\n",
    "  return x / total\n",
    "\n",
    "print(normalize(jnp.arange(8.)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lnvwnlOFrVa-"
   },
   "outputs": [],
   "source": [
    "@partial(pmap, axis_name='rows')\n",
    "@partial(pmap, axis_name='cols')\n",
    "def f(x):\n",
    "  \"\"\"Return psum(x) over 'rows', over 'cols', and over both named axes, in that order.\"\"\"\n",
    "  return (\n",
    "      psum(x, 'rows'),\n",
    "      psum(x, 'cols'),\n",
    "      psum(x, ('rows', 'cols')),\n",
    "  )\n",
    "\n",
    "x = jnp.arange(8.).reshape((4, 2))\n",
    "a, b, c = f(x)\n",
    "\n",
    "print(\"input:\\n\", x)\n",
    "print(\"row sum:\\n\", a)\n",
    "print(\"col sum:\\n\", b)\n",
    "print(\"total sum:\\n\", c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f-FBsWeo1AXE"
   },
   "source": [
    "<img src=\"https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/nested_pmap.png\" width=\"70%\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jC-KIMQ1q-lK"
   },
   "source": [
    "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-A-oVDo6rdWA"
   },
   "source": [
    "### Compose `pmap` with other transforms!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WC_dMIN2rgTZ"
   },
   "outputs": [],
   "source": [
    "@pmap\n",
    "def f(x):\n",
    "  # Inside the outer pmap, `x` is one slice of the argument along its leading axis.\n",
    "  y = jnp.sin(x)\n",
    "  @pmap\n",
    "  def g(z):\n",
    "    # The inner pmap maps over `z`; `x` and `y` are closed over from the enclosing call.\n",
    "    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n",
    "  # Gradient of the summed output of the inner pmapped function, evaluated at x.\n",
    "  return grad(lambda w: jnp.sum(g(w)))(x)\n",
    "  \n",
    "f(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "apuACjPWrixV"
   },
   "outputs": [],
   "source": [
    "grad(lambda x: jnp.sum(f(x)))(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WD9xtROsYX4i"
   },
   "source": [
    "### Compose everything"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "h65c9AQCWAyn"
   },
   "outputs": [],
   "source": [
    "from jax import jvp, vjp  # forward and reverse-mode\n",
    "\n",
    "def curry(f):\n",
    "  \"\"\"Return a callable that binds its arguments to `f` and hands back the resulting partial.\"\"\"\n",
    "  return partial(partial, f)\n",
    "\n",
    "@curry\n",
    "def jacfwd(fun, x):\n",
    "  \"\"\"Jacobian of `fun` at `x`, built from jvp calls batched over standard-basis tangents.\"\"\"\n",
    "  pushfwd = partial(jvp, fun, (x,))  # jvp!\n",
    "  basis = jnp.eye(np.size(x)).reshape((-1,) + jnp.shape(x))\n",
    "  # jvp takes a tuple of tangents, so the basis is wrapped in a 1-tuple before vmapping over it.\n",
    "  tangents = (basis,)\n",
    "  y, jac_flat = vmap(pushfwd, out_axes=(None, -1))(tangents)  # vmap!\n",
    "  return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n",
    "\n",
    "@curry\n",
    "def jacrev(fun, x):\n",
    "  \"\"\"Jacobian of `fun` at `x`, built from the vjp pullback batched over standard-basis cotangents.\"\"\"\n",
    "  y, pullback = vjp(fun, x)  # vjp!\n",
    "  basis = jnp.eye(np.size(y)).reshape((-1,) + jnp.shape(y))\n",
    "  # The pullback returns a 1-tuple (one cotangent per primal input); unpack its only entry.\n",
    "  (jac_flat,) = vmap(pullback)(basis)  # vmap!\n",
    "  return jac_flat.reshape(jnp.shape(y) + jnp.shape(x))\n",
    "\n",
    "def hessian(fun):\n",
    "  \"\"\"Hessian of `fun`: jacfwd applied to jacrev(fun), wrapped in jit.\"\"\"\n",
    "  first_derivative = jacrev(fun)\n",
    "  return jit(jacfwd(first_derivative))  # jit!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "G9qDX84RWhW7"
   },
   "outputs": [],
   "source": [
    "input_hess = hessian(lambda inputs: loss(params, (inputs, targets)))\n",
    "per_example_hess = pmap(input_hess)  # pmap!\n",
    "per_example_hess(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "u3ggM_WYZ8QC"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [
    "AvXl1WDPKjmV"
   ],
   "name": "JAX demo.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
